{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Medical Concept Word Embedding Retrieval (OpenAI)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/pj20/miniconda3/envs/txgnn_env/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import csv\n",
    "from pyhealth.medcode.pretrained_embeddings.lm_emb.openai_retriever import embedding_retrieve as embedding_retriever\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import pickle\n",
    "import json\n",
    "import retrying"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Retry failed embedding requests, waiting exponentially longer between attempts\n",
    "# (capped at 60 s) so that a failing or rate-limited API is not hit with\n",
    "# thousands of immediate back-to-back retries.\n",
    "@retrying.retry(\n",
    "    stop_max_attempt_number=5000,\n",
    "    wait_exponential_multiplier=1000,  # milliseconds\n",
    "    wait_exponential_max=60000,  # milliseconds\n",
    ")\n",
    "def retrieve_embedding(term):\n",
    "    \"\"\"Return the embedding of ``term`` via ``embedding_retriever``.\n",
    "\n",
    "    The call is wrapped by ``retrying.retry`` and is attempted at most\n",
    "    5000 times before giving up.\n",
    "\n",
    "    Args:\n",
    "        term: text to embed (callers in this notebook pass lower-cased concept names).\n",
    "\n",
    "    Returns:\n",
    "        Whatever ``embedding_retriever(term)`` returns.\n",
    "    \"\"\"\n",
    "    return embedding_retriever(term)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Special Tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2/2 [00:01<00:00,  1.78it/s]\n"
     ]
    }
   ],
   "source": [
    "# Embed each special token and save the result as a {token: embedding} JSON file.\n",
    "special_tokens = [\"<pad>\", \"<unk>\"]\n",
    "st_id2emb = {}\n",
    "\n",
    "for token in tqdm(special_tokens):\n",
    "    st_id2emb[token] = embedding_retriever(term=token)\n",
    "\n",
    "with open(\"../resource/LM/special_tokens/special_tokens.json\", \"w\") as f:\n",
    "    json.dump(st_id2emb, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CCSCM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build {CCSCM code: lower-cased name} from columns 0 and 1 of the CSV.\n",
    "# csv.reader is used instead of str.split(',') so that a quoted name\n",
    "# containing commas stays in a single field.\n",
    "ccscm_id2name = {}\n",
    "with open('../resource/CCSCM.csv', 'r', newline='') as f:\n",
    "    reader = csv.reader(f)\n",
    "    next(reader, None)  # skip the header row\n",
    "    for row in reader:\n",
    "        if not row:  # tolerate blank lines\n",
    "            continue\n",
    "        ccscm_id2name[row[0].strip()] = row[1].strip().lower()\n",
    "\n",
    "# One embedding per code; retrieve_embedding retries failed requests.\n",
    "ccscm_id2emb = {}\n",
    "for key in tqdm(ccscm_id2name.keys()):\n",
    "    ccscm_id2emb[key] = retrieve_embedding(term=ccscm_id2name[key])\n",
    "\n",
    "with open(\"../resource/embeddings/LM/conditions/ccscm.json\", \"w\") as f:\n",
    "    json.dump(ccscm_id2emb, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CCSPROC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 231/231 [00:42<00:00,  5.38it/s]\n"
     ]
    }
   ],
   "source": [
    "# Build {CCSPROC code: lower-cased name} from columns 0 and 1 of the CSV.\n",
    "# csv.reader is used instead of str.split(',') so that a quoted name\n",
    "# containing commas stays in a single field.\n",
    "ccsproc_id2name = {}\n",
    "with open('../resource/CCSPROC.csv', 'r', newline='') as f:\n",
    "    reader = csv.reader(f)\n",
    "    next(reader, None)  # skip the header row\n",
    "    for row in reader:\n",
    "        if not row:  # tolerate blank lines\n",
    "            continue\n",
    "        ccsproc_id2name[row[0].strip()] = row[1].strip().lower()\n",
    "\n",
    "# One embedding per code; retrieve_embedding retries failed requests.\n",
    "ccsproc_id2emb = {}\n",
    "for key in tqdm(ccsproc_id2name.keys()):\n",
    "    ccsproc_id2emb[key] = retrieve_embedding(term=ccsproc_id2name[key])\n",
    "\n",
    "with open(\"../resource/embeddings/LM/procedures/ccsproc.json\", \"w\") as f:\n",
    "    json.dump(ccsproc_id2emb, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ATC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 6440/6440 [25:17<00:00,  4.24it/s]   \n"
     ]
    }
   ],
   "source": [
    "# Build {ATC code: lower-cased name} from the 'code' and 'name' columns.\n",
    "# Every row is kept: there is no filtering on row['level'].\n",
    "atc_id2name = {}\n",
    "with open(\"../resource/ATC.csv\", newline='') as csvfile:\n",
    "    reader = csv.DictReader(csvfile)\n",
    "    for row in reader:\n",
    "        atc_id2name[row['code']] = row['name'].lower()\n",
    "\n",
    "# One embedding per code; retrieve_embedding retries failed requests.\n",
    "atc_id2emb = {}\n",
    "for key in tqdm(atc_id2name.keys()):\n",
    "    atc_id2emb[key] = retrieve_embedding(term=atc_id2name[key])\n",
    "\n",
    "with open(\"../resource/embeddings/LM/drugs/atc.json\", \"w\") as f:\n",
    "    json.dump(atc_id2emb, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ICD9CM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 17736/17736 [52:29<00:00,  5.63it/s]  \n"
     ]
    }
   ],
   "source": [
    "from pyhealth.medcode import ICD9CM\n",
    "\n",
    "# Build {ICD9CM code: lower-cased name} from columns 0 and 2 of the CSV.\n",
    "# csv.reader is used instead of str.split(',') so that a quoted name\n",
    "# containing commas is not truncated at its first comma.\n",
    "icd9cm_id2name = {}\n",
    "with open('../resource/ICD9CM.csv', 'r', newline='') as f:\n",
    "    reader = csv.reader(f)\n",
    "    next(reader, None)  # skip the header row\n",
    "    for row in reader:\n",
    "        if not row:  # tolerate blank lines\n",
    "            continue\n",
    "        icd9cm_id2name[row[0].strip()] = row[2].strip().lower()\n",
    "\n",
    "# Embeddings are keyed by the standardized code with the '.' removed.\n",
    "icd9cm_id2emb = {}\n",
    "for key in tqdm(icd9cm_id2name.keys()):\n",
    "    emb = retrieve_embedding(term=icd9cm_id2name[key])\n",
    "    icd9cm_id2emb[ICD9CM.standardize(key).replace('.', '')] = emb\n",
    "\n",
    "with open(\"../resource/embeddings/LM/conditions/icd9cm.json\", \"w\") as f:\n",
    "    json.dump(icd9cm_id2emb, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ICD9PROC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 4670/4670 [44:28<00:00,  1.75it/s]  \n"
     ]
    }
   ],
   "source": [
    "from pyhealth.medcode import ICD9PROC\n",
    "\n",
    "# Build {ICD9PROC code: lower-cased name} from columns 0 and 2 of the CSV.\n",
    "# csv.reader is used instead of str.split(',') so that a quoted name\n",
    "# containing commas is not truncated at its first comma.\n",
    "icd9proc_id2name = {}\n",
    "with open('../resource/ICD9PROC.csv', 'r', newline='') as f:\n",
    "    reader = csv.reader(f)\n",
    "    next(reader, None)  # skip the header row\n",
    "    for row in reader:\n",
    "        if not row:  # tolerate blank lines\n",
    "            continue\n",
    "        icd9proc_id2name[row[0].strip()] = row[2].strip().lower()\n",
    "\n",
    "# Embeddings are keyed by the standardized code with the '.' removed.\n",
    "icd9proc_id2emb = {}\n",
    "for key in tqdm(icd9proc_id2name.keys()):\n",
    "    emb = retrieve_embedding(term=icd9proc_id2name[key])\n",
    "    icd9proc_id2emb[ICD9PROC.standardize(key).replace('.', '')] = emb\n",
    "\n",
    "with open(\"../resource/embeddings/LM/procedures/icd9proc.json\", \"w\") as f:\n",
    "    json.dump(icd9proc_id2emb, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rewrite the saved ICD9CM embeddings in place with '.' removed from every key.\n",
    "icd9cm_emb_path = \"../resource/embeddings/LM/conditions/icd9cm.json\"\n",
    "\n",
    "with open(icd9cm_emb_path, \"r\") as f:\n",
    "    icd9cm_id2emb = json.load(f)\n",
    "\n",
    "# If two keys collapse to the same dot-less code, the later one wins.\n",
    "icd9cm_id2emb_new = {\n",
    "    code.replace('.', ''): vector for code, vector in icd9cm_id2emb.items()\n",
    "}\n",
    "\n",
    "with open(icd9cm_emb_path, \"w\") as f:\n",
    "    json.dump(icd9cm_id2emb_new, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Rewrite the saved ICD9PROC embeddings in place with '.' removed from every key.\n",
    "with open(\"../resource/embeddings/LM/gpt3/procedures/icd9proc.json\", \"r\") as f:\n",
    "    icd9proc_id2emb = json.load(f)\n",
    "\n",
    "icd9proc_id2emb_new = {}\n",
    "for key, value in icd9proc_id2emb.items():\n",
    "    icd9proc_id2emb_new[key.replace('.', '')] = value\n",
    "\n",
    "# Give codes '3605' and '3602' the embeddings stored under '0066' and '36'.\n",
    "# These assignments do not depend on the loop variable, so they run once here,\n",
    "# after the loop, which also keeps them taking precedence over loop-written keys.\n",
    "# NOTE(review): the reason for these two specific mappings is not visible in\n",
    "# this notebook - confirm against the code tables.\n",
    "icd9proc_id2emb_new['3605'] = icd9proc_id2emb['0066']\n",
    "icd9proc_id2emb_new['3602'] = icd9proc_id2emb['36']\n",
    "\n",
    "with open(\"../resource/embeddings/LM/gpt3/procedures/icd9proc.json\", \"w\") as f:\n",
    "    json.dump(icd9proc_id2emb_new, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.16 ('txgnn_env')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "79cb95e61c4f960f4e102f21c45668d32cb5c494b237694c15d64b50342e6e99"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
